{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
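    "# RNN and LSTM Regression Models\n",
    "\n",
    "This notebook trains a simple RNN and an LSTM with deep-river's `RollingRegressor` on the first 5,000 samples of River's Bikes dataset, predicting each sample before learning from it, and reports the resulting MAE for each model."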
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:21.946635Z",
     "iopub.status.busy": "2025-09-28T08:21:21.946416Z",
     "iopub.status.idle": "2025-09-28T08:21:26.147966Z",
     "shell.execute_reply": "2025-09-28T08:21:26.147143Z"
    }
   },
   "outputs": [],
   "source": [
    "from deep_river.regression import RollingRegressor\n",
    "from river import (\n",
    "    metrics,\n",
    "    compose,\n",
    "    preprocessing,\n",
    "    datasets,\n",
    "    stats,\n",
    "    feature_extraction,\n",
    ")\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:26.151131Z",
     "iopub.status.busy": "2025-09-28T08:21:26.150814Z",
     "iopub.status.idle": "2025-09-28T08:21:26.154122Z",
     "shell.execute_reply": "2025-09-28T08:21:26.153319Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "def get_hour(x):\n",
    "    x[\"hour\"] = x[\"moment\"].hour\n",
    "    return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Simple RNN Regression Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:26.157612Z",
     "iopub.status.busy": "2025-09-28T08:21:26.157282Z",
     "iopub.status.idle": "2025-09-28T08:21:26.162653Z",
     "shell.execute_reply": "2025-09-28T08:21:26.161819Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "class RnnModule(nn.Module):\n",
    "\n",
    "    def __init__(self, n_features, hidden_size):\n",
    "        super().__init__()\n",
    "        self.n_features = n_features\n",
    "        self.rnn = nn.RNN(\n",
    "            input_size=n_features, hidden_size=hidden_size, num_layers=1\n",
    "        )\n",
    "        self.fc = nn.Linear(in_features=hidden_size, out_features=1)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        output, hn = self.rnn(X)  # lstm with input, hidden, and internal state\n",
    "        return self.fc(output[-1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:26.165849Z",
     "iopub.status.busy": "2025-09-28T08:21:26.165517Z",
     "iopub.status.idle": "2025-09-28T08:21:27.130999Z",
     "shell.execute_reply": "2025-09-28T08:21:27.130265Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><div class=\"river-component river-union\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">['clouds', [...]</pre></summary><code class=\"river-estimator-params\">Select (\n",
       "  clouds\n",
       "  humidity\n",
       "  pressure\n",
       "  temperature\n",
       "  wind\n",
       ")\n",
       "</code></details><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">get_hour</pre></summary><code class=\"river-estimator-params\">\n",
       "def get_hour(x):\n",
       "    x[\"hour\"] = x[\"moment\"].hour\n",
       "    return x\n",
       "\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">y_mean_by_station_and_hour</pre></summary><code class=\"river-estimator-params\">TargetAgg (\n",
       "  by=['station', 'hour']\n",
       "  how=Mean ()\n",
       "  target_name=\"y\"\n",
       ")\n",
       "</code></details></div></div><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RollingRegressor</pre></summary><code class=\"river-estimator-params\">RollingRegressor (\n",
       "  module=RnnModule(\n",
       "  (rnn): RNN(10, 16)\n",
       "  (fc): Linear(in_features=16, out_features=1, bias=True)\n",
       ")\n",
       "  loss_fn=\"mse\"\n",
       "  optimizer_fn=\"sgd\"\n",
       "  lr=0.01\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       "  window_size=20\n",
       "  append_predict=True\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  TransformerUnion (\n",
       "    Select (\n",
       "      clouds\n",
       "      humidity\n",
       "      pressure\n",
       "      temperature\n",
       "      wind\n",
       "    ),\n",
       "    Pipeline (\n",
       "      FuncTransformer (\n",
       "        func=\"get_hour\"\n",
       "      ),\n",
       "      TargetAgg (\n",
       "        by=['station', 'hour']\n",
       "        how=Mean ()\n",
       "        target_name=\"y\"\n",
       "      )\n",
       "    )\n",
       "  ),\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  RollingRegressor (\n",
       "    module=RnnModule(\n",
       "    (rnn): RNN(10, 16)\n",
       "    (fc): Linear(in_features=16, out_features=1, bias=True)\n",
       "  )\n",
       "    loss_fn=\"mse\"\n",
       "    optimizer_fn=\"sgd\"\n",
       "    lr=0.01\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "    window_size=20\n",
       "    append_predict=True\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = datasets.Bikes()\n",
    "metric = metrics.MAE()\n",
    "\n",
    "# Features: five raw weather readings plus the running mean of the target\n",
    "# per (station, hour), using the hour extracted by get_hour.\n",
    "model_pipeline = compose.Select(\n",
    "    \"clouds\", \"humidity\", \"pressure\", \"temperature\", \"wind\"\n",
    ")\n",
    "model_pipeline += get_hour | feature_extraction.TargetAgg(\n",
    "    by=[\"station\", \"hour\"], how=stats.Mean()\n",
    ")\n",
    "model_pipeline |= preprocessing.StandardScaler()\n",
    "# The network is passed as a ready-built instance, so its size is fixed by the\n",
    "# constructor call RnnModule(n_features=10, hidden_size=16). To try another\n",
    "# hidden size, change it there; RollingRegressor does not resize the module.\n",
    "# NOTE(review): the union above appears to yield 6 features (5 selected + 1\n",
    "# TargetAgg output) while the module expects 10 inputs - confirm how\n",
    "# deep-river maps the feature dict onto the input layer.\n",
    "model_pipeline |= RollingRegressor(\n",
    "    module=RnnModule(10, 16),\n",
    "    loss_fn=\"mse\",\n",
    "    optimizer_fn=\"sgd\",\n",
    "    window_size=20,\n",
    "    lr=1e-2,\n",
    "    append_predict=True,\n",
    ")\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:27.134142Z",
     "iopub.status.busy": "2025-09-28T08:21:27.133809Z",
     "iopub.status.idle": "2025-09-28T08:21:37.418160Z",
     "shell.execute_reply": "2025-09-28T08:21:37.417601Z"
    },
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE: 3.37\n"
     ]
    }
   ],
   "source": [
    "for x, y in dataset.take(5000):\n",
    "    # Score the prediction first, then learn from the same sample.\n",
    "    prediction = model_pipeline.predict_one(x)\n",
    "    metric.update(y_true=y, y_pred=prediction)\n",
    "    model_pipeline.learn_one(x=x, y=y)\n",
    "print(f\"MAE: {metric.get():.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## LSTM Regression Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:37.420640Z",
     "iopub.status.busy": "2025-09-28T08:21:37.420420Z",
     "iopub.status.idle": "2025-09-28T08:21:37.424759Z",
     "shell.execute_reply": "2025-09-28T08:21:37.424147Z"
    }
   },
   "outputs": [],
   "source": [
    "class LstmModule(nn.Module):\n",
    "\n",
    "    def __init__(self, n_features, hidden_size=1):\n",
    "        super().__init__()\n",
    "        self.n_features = n_features\n",
    "        self.hidden_size = hidden_size\n",
    "        self.lstm = nn.LSTM(\n",
    "            input_size=n_features,\n",
    "            hidden_size=hidden_size,\n",
    "            num_layers=1,\n",
    "            bidirectional=False,\n",
    "        )\n",
    "        self.fc = nn.Linear(in_features=hidden_size, out_features=1)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        output, (hn, cn) = self.lstm(\n",
    "            X\n",
    "        )  # lstm with input, hidden, and internal state\n",
    "        return self.fc(output[-1, :])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:37.426970Z",
     "iopub.status.busy": "2025-09-28T08:21:37.426762Z",
     "iopub.status.idle": "2025-09-28T08:21:37.435331Z",
     "shell.execute_reply": "2025-09-28T08:21:37.434620Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><div class=\"river-component river-union\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">['clouds', [...]</pre></summary><code class=\"river-estimator-params\">Select (\n",
       "  clouds\n",
       "  humidity\n",
       "  pressure\n",
       "  temperature\n",
       "  wind\n",
       ")\n",
       "</code></details><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">get_hour</pre></summary><code class=\"river-estimator-params\">\n",
       "def get_hour(x):\n",
       "    x[\"hour\"] = x[\"moment\"].hour\n",
       "    return x\n",
       "\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">y_mean_by_station_and_hour</pre></summary><code class=\"river-estimator-params\">TargetAgg (\n",
       "  by=['station', 'hour']\n",
       "  how=Mean ()\n",
       "  target_name=\"y\"\n",
       ")\n",
       "</code></details></div></div><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">RollingRegressor</pre></summary><code class=\"river-estimator-params\">RollingRegressor (\n",
       "  module=LstmModule(\n",
       "  (lstm): LSTM(10, 16)\n",
       "  (fc): Linear(in_features=16, out_features=1, bias=True)\n",
       ")\n",
       "  loss_fn=\"mse\"\n",
       "  optimizer_fn=\"sgd\"\n",
       "  lr=0.01\n",
       "  is_feature_incremental=False\n",
       "  device=\"cpu\"\n",
       "  seed=42\n",
       "  window_size=20\n",
       "  append_predict=True\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  TransformerUnion (\n",
       "    Select (\n",
       "      clouds\n",
       "      humidity\n",
       "      pressure\n",
       "      temperature\n",
       "      wind\n",
       "    ),\n",
       "    Pipeline (\n",
       "      FuncTransformer (\n",
       "        func=\"get_hour\"\n",
       "      ),\n",
       "      TargetAgg (\n",
       "        by=['station', 'hour']\n",
       "        how=Mean ()\n",
       "        target_name=\"y\"\n",
       "      )\n",
       "    )\n",
       "  ),\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  RollingRegressor (\n",
       "    module=LstmModule(\n",
       "    (lstm): LSTM(10, 16)\n",
       "    (fc): Linear(in_features=16, out_features=1, bias=True)\n",
       "  )\n",
       "    loss_fn=\"mse\"\n",
       "    optimizer_fn=\"sgd\"\n",
       "    lr=0.01\n",
       "    is_feature_incremental=False\n",
       "    device=\"cpu\"\n",
       "    seed=42\n",
       "    window_size=20\n",
       "    append_predict=True\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = datasets.Bikes()\n",
    "metric = metrics.MAE()\n",
    "\n",
    "# Features: five raw weather readings plus the running mean of the target\n",
    "# per (station, hour), using the hour extracted by get_hour.\n",
    "model_pipeline = compose.Select(\n",
    "    \"clouds\", \"humidity\", \"pressure\", \"temperature\", \"wind\"\n",
    ")\n",
    "model_pipeline += get_hour | feature_extraction.TargetAgg(\n",
    "    by=[\"station\", \"hour\"], how=stats.Mean()\n",
    ")\n",
    "model_pipeline |= preprocessing.StandardScaler()\n",
    "# The network is passed as a ready-built instance, so its size is fixed by the\n",
    "# constructor call LstmModule(n_features=10, hidden_size=16). To try another\n",
    "# hidden size, change it there; RollingRegressor does not resize the module.\n",
    "# NOTE(review): the union above appears to yield 6 features (5 selected + 1\n",
    "# TargetAgg output) while the module expects 10 inputs - confirm how\n",
    "# deep-river maps the feature dict onto the input layer.\n",
    "model_pipeline |= RollingRegressor(\n",
    "    module=LstmModule(10, 16),\n",
    "    loss_fn=\"mse\",\n",
    "    optimizer_fn=\"sgd\",\n",
    "    window_size=20,\n",
    "    lr=1e-2,\n",
    "    append_predict=True,\n",
    ")\n",
    "model_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-28T08:21:37.437511Z",
     "iopub.status.busy": "2025-09-28T08:21:37.437296Z",
     "iopub.status.idle": "2025-09-28T08:21:46.550822Z",
     "shell.execute_reply": "2025-09-28T08:21:46.550083Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAE: 2.80\n"
     ]
    }
   ],
   "source": [
    "for x, y in dataset.take(5000):\n",
    "    # Score the prediction first, then learn from the same sample.\n",
    "    prediction = model_pipeline.predict_one(x)\n",
    "    metric.update(y_true=y, y_pred=prediction)\n",
    "    model_pipeline.learn_one(x=x, y=y)\n",
    "print(f\"MAE: {metric.get():.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-river",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
